"""
Tensorflow 2.x
"""
import tensorflow as tf
import numpy as np
from python_ai.common.xcommon import *
from tensorflow.keras.datasets import mnist

# Load MNIST; only the training split is fed into the pipeline below.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

SHUFFLE_BUFFER = 2000  # number of elements held in the shuffle buffer
BATCH_SIZE = 18000  # drop_remainder=True below discards any trailing partial batch

# Input pipeline: pair images with labels, shuffle them, group into
# fixed-size batches, and let tf.data choose the prefetch depth.
ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Walk the batches: convert each (images, labels) tensor pair to NumPy
# and print information about both arrays.
for i, (bx, by) in enumerate(ds):
    sep(i)  # helper from xcommon; presumably prints a separator tagged with the batch index
    bx, by = bx.numpy(), by.numpy()
    for arr, label in ((bx, 'bx'), (by, 'by')):
        print_numpy_ndarray_info(arr, label)
